# 特征聚合方式

import torch
from torch_geometric.utils import scatter


# Toy input: 3 rows with 2 features each. Built from Python ints, so the
# tensor dtype is int64.
data = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Group id of each row of `data`: rows 0 and 1 -> group 0, row 2 -> group 1.
index = torch.tensor([0, 0, 1])

# Sum the rows belonging to each group along dim 0, giving one row per group:
# tensor([[4, 6], [5, 6]]).
agg = scatter(data, index, dim=0, reduce='sum')
# Other reductions available through the same call:
# agg = scatter(data, index, dim=0, reduce='mean')
# agg = scatter(data, index, dim=0, reduce='min')
# agg = scatter(data, index, dim=0, reduce='max')

# Indexing the per-group result with `index` copies each group's aggregate
# back onto every row of that group, so `fea` has as many rows as `data`.
fea = agg[index]
print(fea)
# Output: tensor([[4, 6], [4, 6], [5, 6]])

# Subtract the maximum before exponentiating so every exponent is <= 0 and
# exp() cannot overflow (the standard stable-softmax shift). Note that
# fea.max() is one scalar over the whole tensor, not a per-group maximum.
shifted = fea - fea.max()
print(shifted)
# Output: tensor([[-2, 0], [-2, 0], [-1, 0]])
# torch.exp returns a floating-point tensor even though `shifted` is int64.
pw = torch.exp(shifted)
print(pw)
# Output: tensor([[0.1353, 1.0000], [0.1353, 1.0000], [0.3679, 1.0000]])

